Implement get_unit_spike_trains and performance improvements#4502

Open
alejoe91 wants to merge 24 commits into SpikeInterface:main from
alejoe91:get-unit-spike-trains

Conversation

@alejoe91
Member

@alejoe91 alejoe91 commented Apr 9, 2026

  • expose and propagate use_cache (to get_unit_spike_train_in_seconds)
  • fix wrong check in to_reordered_spike_vector
  • avoid lexsort when not needed in select_units

@grahamfindlay

TODO

  • Implement numpy/numba get_unit_spike_trains for PhyKilosortSortingExtractor

(maybe in follow up)

@alejoe91 alejoe91 requested a review from chrishalcrow April 9, 2026 15:41
@alejoe91 alejoe91 added the labels core (Changes to core module) and performance (Performance issues/improvements) Apr 9, 2026
alejoe91 and others added 8 commits April 9, 2026 17:58
- Drop unused `return_times` parameter from get_unit_spike_trains_in_seconds
- Clean up stale/truncated docstrings on get_unit_spike_train_in_seconds,
  get_unit_spike_trains, and get_unit_spike_trains_in_seconds
- Fix UnitsSelectionSortingSegment.get_unit_spike_trains to re-key the
  returned dict with child unit ids (was returning parent-keyed dict,
  breaking whenever renamed_unit_ids differ from parent ids)
- Fix test_get_unit_spike_trains: drop unused return_times kwarg, remove
  unused local variable, fix assertion.
The previous check `np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0`
was never `True`, because `ids_to_indices(self._renamed_unit_ids)` on a USS
always returns `[0, 1, ..., k-1]` (since `_main_ids == _renamed_unit_ids`), so the
diff was always positive and the lexsort branch was unreachable. Therefore the
cached spike vector was wrong whenever two units had co-temporal spikes and the
selection reordered them relative to the parent.

Replaced with a two-step check that attempts to avoid unnecessary lexsorts:
  1. O(k) `_is_order_preserving_selection()` -- Checks if USS `._unit_ids` is
     in the same relative order as in the parent. When True, the remapped vector
     is guaranteed sorted (boolean filtering preserves order; the remap only
     relabels unit_index values). This is the common case via `select_units()`
     with a boolean mask.
  2. O(n) `_is_spike_vector_sorted()` -- Checks if the remapped vector is still
     sorted by (segment, sample, unit). Catches the case where the selection is
     not order-preserving but no co-temporal (same exact sample) spikes exist.

Falls back to the original O(n log n) lexsort only when both checks fail.
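The two checks described above could be sketched roughly like this. This is an illustrative standalone version, not the PR code: the function names echo `_is_order_preserving_selection()` and `_is_spike_vector_sorted()` from the commit message, but signatures and details are assumptions.

```python
import numpy as np


def is_order_preserving_selection(parent_indices: np.ndarray) -> bool:
    """O(k): True if the selected units appear in the same relative order
    as in the parent, i.e. their parent indices are strictly increasing."""
    return bool(np.all(np.diff(parent_indices) > 0))


def is_spike_vector_sorted(spike_vector: np.ndarray) -> bool:
    """O(n): True if the vector is already sorted by
    (segment_index, sample_index, unit_index), so no lexsort is needed."""
    seg = spike_vector["segment_index"]
    samp = spike_vector["sample_index"]
    unit = spike_vector["unit_index"]
    # compare consecutive rows lexicographically without sorting
    seg_up = seg[1:] > seg[:-1]
    seg_eq = seg[1:] == seg[:-1]
    samp_up = samp[1:] > samp[:-1]
    samp_eq = samp[1:] == samp[:-1]
    unit_ok = unit[1:] >= unit[:-1]
    ok = seg_up | (seg_eq & (samp_up | (samp_eq & unit_ok)))
    return bool(np.all(ok))
```

Only if both of these return False would the O(n log n) lexsort fallback run.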
`BaseSorting` builds the spike vector with a per-unit boolean scan
over spike_clusters, which is O(N*K).

If we already have the full flat spike time and spike cluster arrays, we can
do a lot better by building the spike vector in one shot.
(I think O(N log N) from the lexsort, which is also pessimistic,
because the lexsort doesn't always need to happen.
Under any circumstances I can dream of, K >> log N.)

Since Phy/Kilosort segments already load the full flat arrays when the
`PhyKilosortSorting` object is created, and keep them around as
`._all_spikes` and `._all_clusters`, we can just use those! :)

Also populates `_cached_spike_vector_segment_slices` directly, so
that `BaseSorting`'s `_get_spike_vector_segment_slices()` lazy
recomputation is skipped.
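A minimal sketch of the one-shot build for a single-segment sorting, assuming a dense `cluster_to_unit_index` lookup array; the function name and signature are illustrative, not the actual PR code:

```python
import numpy as np


def one_shot_spike_vector(all_spikes, all_clusters, cluster_to_unit_index):
    """Build the spike vector in one shot from the flat per-spike arrays,
    instead of a per-unit boolean scan (O(N*K) -> O(N log N) lexsort)."""
    dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
    sv = np.zeros(all_spikes.size, dtype=dtype)
    sv["sample_index"] = all_spikes
    # vectorized cluster-id -> unit-index remap via a lookup table
    sv["unit_index"] = cluster_to_unit_index[all_clusters]
    # segment_index stays 0: Phy/Kilosort sortings are single-segment
    # lexsort: the last key is the primary sort key
    order = np.lexsort((sv["unit_index"], sv["sample_index"]))
    return sv[order]
```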
`BaseSortingSegment.get_unit_spike_trains()` loops over
`get_unit_spike_train`, which is O(N*K) because each call is a
boolean scan over _all_clusters/_all_spikes.

If we know we are going to be getting all the trains, we can do it
much faster. And if we can use numba, even faster still.

In fact, even if we only want _some_ spike trains, it is still often
faster to get all the trains and just discard the ones we don't need,
than to get only the trains we need unit-by-unit (because we
only ever store or cache flat arrays of spike times/clusters).

Note that **only the use_cache=False path is affected**; the
use_cache=True path triggers the computation of the spike vector, which
I don't think can ever be the most efficient way to get spike trains.
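The "get all the trains at once" idea can be sketched as a counting-sort-style scatter. This is a plain-Python illustration under assumed inputs (`unit_indices` as a dense 0..K-1 remap of `_all_clusters`); the production version would be the numba-accelerated path:

```python
import numpy as np


def scatter_all_trains(all_spikes, unit_indices, num_units):
    """One pass over the flat arrays builds every spike train,
    instead of one O(N) boolean scan per unit."""
    # counts + prefix sums give each unit's slice in one flat output array
    counts = np.bincount(unit_indices, minlength=num_units)
    offsets = np.concatenate(([0], np.cumsum(counts)))
    out = np.empty_like(all_spikes)
    cursors = offsets[:-1].copy()
    # single pass: drop each spike into its unit's next free slot;
    # all_spikes is time-sorted, so each train comes out time-sorted too
    for i in range(all_spikes.size):
        u = unit_indices[i]
        out[cursors[u]] = all_spikes[i]
        cursors[u] += 1
    return {u: out[offsets[u]:offsets[u + 1]] for u in range(num_units)}
```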
…izations

- Fixed test_compute_and_cache_spike_vector: was comparing an array to
  itself (to_spike_vector use_cache=False still returns the cached
  vector). Now explicitly calls the USS override and the BaseSorting
  implementation, and compares the two.
- Added test_uss_get_unit_spike_trains_with_renamed_ids: also not a test
  of the optimization commits per se, but would have caught a mistake made
  along the way. Verifies get_unit_spike_trains returns child-keyed dicts
  (not parent-keyed).
- Added test_spike_vector_sorted_after_reorder_with_cotemporal_spikes:
  verifies the USS spike vector is correctly sorted when the selection
  reverses unit order and co-temporal spikes exist.
- Added test_phy_sorting_segment_get_unit_spike_trains: validates the
  new fast methods on PhySortingSegment.
- Added test_phy_compute_and_cache_spike_vector: verifies the Phy
  override of _compute_and_cache_spike_vector matches BaseSorting
  implementation.
@grahamfindlay
Contributor

@alejoe91 my changes PR'd to your fork whenever you're ready.

The only thing I should point out that isn't in the commit messages:
I mocked a minimal Phy folder for testing instead of using the phy_example_0 GIN dataset, just because it was quick, easy, and lightweight. I did feel a little guilty doing it, but I'm also not convinced it was a bad idea.

@alejoe91 alejoe91 marked this pull request as ready for review April 14, 2026 09:47
@h-mayorquin
Collaborator

I am curious what prompted this. What profiling did you guys do? Any chance we could have a discussion here on the repo, at least to know what the performance benchmarks, reasoning, and validation were?

@alejoe91
Member Author

@grahamfindlay is doing very long chronic recordings. He does all the processing and at a second iteration wants to load the phy sorting object, select some units, and get all the spike trains.

Just caching the spike vector takes almost 4 minutes! Plus there were some additional lexsorts that could be avoided to speed up computation.

At least to give some context @h-mayorquin

@grahamfindlay maybe you can add some more details on benchmarks and profiling?

@grahamfindlay
Contributor

grahamfindlay commented Apr 16, 2026

Here are example timings for various operations using 1 example subject. This subject only has ~400 million spikes - I have some with many more. FWIW, you shouldn't need long chronic recordings to see tangible improvements from most of these changes. I'd have to dig through my notes, but I tested with 100M spikes and the gains were still clear.

"Parent before" = The KiloSortSortingExtractor (342 units), pre-PR
"Parent after" = The KiloSortSortingExtractor, with PR
"Leaf before" = Two layers of UnitSelectionSorting (first layer: 258 units, second layer: 258 units), pre-PR
"Leaf after" = UnitSelectionSorting, with PR

The two layers of UnitSelectionSorting come from 1) selecting based on quality, 2) selecting based on cell type (here I asked for all cell types, so the second layer should effectively be a no-op, but it actually has a big cost).

| Operation | Parent Before | Parent After | Leaf Before | Leaf After | Notes |
| --- | --- | --- | --- | --- | --- |
| to_spike_vector() | 5m58s | 2m51s | 5m58s + 4m30s | 2m51s + 21s | Time for parent + marginal time for children |
| precompute_spike_trains() | +2m18s | +2m10s | +2m47s | +1m54s | Starting from a hot cache and precomputed parent spike trains (for leaf), i.e. best case |
| loop over get_unit_spike_train(use_cache=False) | 3m40s | 3m40s | ~13m16s (bug) | 3m15s | Bugs: wasn't available when return_times=True; use_cache was never respected by UnitSelectionSorting |
| get_unit_spike_trains(use_cache=True) | N/A | Same as precompute_spike_trains() | N/A | Same as precompute_spike_trains() | Just syntactic sugar; still relies on the spike vector |
| get_unit_spike_trains(use_cache=False) | 3m40s | 35s (numba) / 1m49s (numpy) | 3m15s | 35s (numba) / 1m49s (numpy) | Numba / NumPy should be 11s / 1m15s; must fix |

Comments:

  • The 4m30s to get the UnitSelectionSorting (USS) spike vector was 2m15s per layer, including the no-op layer.
  • The improvements in to_spike_vector() come from overriding the base class method to take advantage of the fact that the KS extractor already has access to the full flat arrays, is single-segment, and from checks to avoid needless lexsorting on the USS.
    • One of the things that can trigger lexsorting is if the user selects unit ids in a different relative order than they appeared in the parent. I handle this pathological case, but it might be worth discussing whether they should be allowed to do this in the first place, or whether we should always re-order the ids to match the parent. Hopefully it's uncommon in practice.
  • The fact that it still takes minutes to precompute the spike trains for a USS after precomputing the spike trains for the KiloSortSortingExtractor (ie all units) is conspicuous. You're asking for a subset of the trains you already precomputed -- it should be instant!
  • The best way to get spike trains before was to bypass the cache so you could bypass the spike vector computation, but because of bugs you couldn't do this from a USS, or if you wanted to return times in seconds (without accessing private properties like ._parent_segment).
  • The new get_spike_trains() takes advantage of the fact that if you know you will get all spike trains, you can again take advantage of the full flat (sorted) arrays and just scatter them. Basically, to get spike trains, you don't have to figure out if two spikes from different units occur on the same sample, which you do need to know for the spike vector. In fact, it's so much cheaper that even if you only want ~20 spike trains, it's better to just get them all and discard the ones you aren't interested in...
  • You could probably apply these same principles to get gains for other extractors besides the Phy/KiloSort ones.

There are rough edges with this PR I know about:

  1. Alessio and Samuel pointed out that the BasePhyKilosortSortingExtractor is never multi-segment, so I can remove a loop over segments and save a possibly expensive call to np.concatenate(). UPDATE 4/7: Fixed.
  2. When translating my prototype numba and numpy versions of get_spike_trains() to the production versions, I made some minor changes to fit function signatures and style that I thought would be harmless, but apparently they add ~20s to the numba implementation and 30s to the numpy one. This is pretty significant, as it took the numba path from 11s to 35s. I need to go back and figure out why the effect of these seemingly small changes was so dramatic. UPDATE 4/7: Solved, but not committed. See below.

Another question I haven't resolved:

  • Why is building the spike vector so costly in the first place, even using my new more efficient method for 1-shot'ing it from the full flat arrays? It just feels like there should be a better way. As Samuel pointed out, it's 1 big malloc as opposed to K (K = units) mallocs for a dictionary of trains. I don't think I'm just creating views onto the underlying flat arrays... so what gives? Is it really all in the O(N log N) lexsort, and the need to allocate a lot of space for the segment indices? UPDATE 4/7: Probably solved, but not committed. See below.

Comment on lines +352 to +375
```python
if use_cache:
    # TODO: speed things up
    ordered_spike_vector, slices = self.to_reordered_spike_vector(
        lexsort=("sample_index", "segment_index", "unit_index"),
        return_order=False,
        return_slices=True,
    )
    unit_indices = self.ids_to_indices(unit_ids)
    spike_trains = {}
    for unit_index, unit_id in zip(unit_indices, unit_ids):
        sl0, sl1 = slices[unit_index, segment_index, :]
        spikes = ordered_spike_vector[sl0:sl1]
        spike_frames = spikes["sample_index"]
        if start_frame is not None:
            start = np.searchsorted(spike_frames, start_frame)
            spike_frames = spike_frames[start:]
        if end_frame is not None:
            end = np.searchsorted(spike_frames, end_frame)
            spike_frames = spike_frames[:end]
        spike_trains[unit_id] = spike_frames
else:
    spike_trains = segment.get_unit_spike_trains(
        unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame
    )
```
Member

I think this is overkill, and should be replaced with something like

```python
spike_trains = {
    unit_id: self.get_unit_spike_train(
        unit_id, start_frame=start_frame, end_frame=end_frame, use_cache=use_cache
    )
    for unit_id in unit_ids
}
return spike_trains
```

In my local testing, this gives the same speed results. The one gain is avoiding repeated time slicing, but I think it's a marginal gain for all this fairly confusing code. The spike train code is already fairly complex.

Happy to be proven wrong with benchmarking from a very long recording, but I don't think it's worth doing this until we need it.

Contributor

For the use_cache=True path:

  • I see basically no difference in performance for a cold cache (because lexsort dominates)
  • I see ~10-20x benefit for a hot cache (2M spikes, 800 units), but it's only ~ 10ms vs 150ms, so negligible.
    I'm not surprised the gains are marginal - it's just the repeated per-unit segment checks, id_to_index(), etc. that are saved, so adding more spikes won't really change much. This path still relies on the spike vector, and that will always be slow (see the precompute_spike_trains() entry in the table above).

However, for the use_cache=False path:

  • I think it is important that the use_cache=False path continue to dispatch to segment.get_unit_spike_trains() (emphasis on segment not self and trains not train), so that the segment can override the multi-unit path - that is where the biggest (eg the numba accelerated) gains in this PR (seconds vs minutes) come from.


```python
def get_unit_spike_trains(
    self,
    unit_ids: np.ndarray | list,
```
Member

I'd vote to get all spike trains if the user doesn't pass unit_ids. Surely almost all use cases for get_unit_spike_trains are to get all unit spike trains?

Contributor

@grahamfindlay grahamfindlay Apr 17, 2026

I'm all for that. I actually think that unless we are going to cache the spike trains as such (rather than as a reordered spike vector) -- and I don't think we are [1] -- we should just call the function get_all_spike_trains() and return them all. That would most accurately reflect what the function does. It would make it less surprising for the user that getting 30 spike trains takes the same time as getting 300 spike trains. It would probably also encourage better access patterns. And they can easily filter the dict themselves with a 1-liner like:

```python
spike_trains = {id: train for id, train in sorting.get_all_spike_trains().items() if id in unit_ids}
```

[1] I think caching both the spike trains and the spike vector would be bad, since the caches could drift out of sync with each other unless care were taken to avoid this, and syncing caches would presumably negate all the benefits of using one representation over another.

Contributor

Also, I have tracked down the regressions I mentioned above:

There are rough edges with this PR I know about:
...
When translating my prototype numba and numpy versions of get_spike_trains() to the production versions, I made some minor changes to fit function signatures and style that I thought would be harmless, but apparently they add ~20s to the numba implementation and 30s to the numpy one. This is pretty significant, as it took the numba path from 11s to 35s. I need to go back and figure out why the effect of these seemingly small changes was so dramatic.

They are all related to the handling of the unit_ids argument. I can address them while preserving the argument, I think... but turning get_unit_spike_trains(unit_ids, ...) into get_all_spike_trains() would make them all trivially disappear, and would simplify things quite a bit.

Comment on lines +420 to +421
```python
start_frame=start_frame,
end_frame=end_frame,
```
Member

@chrishalcrow chrishalcrow Apr 17, 2026

I was expecting this code to use the times to figure out the start/end frames, and use them here. Instead, this code gets all spike trains then slices. Why?
(EDIT: I'm sure there is a good reason I've not thought of!!)

Contributor

@grahamfindlay grahamfindlay Apr 17, 2026

I also was confused by this at first, but I think it is because there's no guarantee that the sample returned by BaseRecording.time_to_sample_index() exactly corresponds to the time you give it (it is more like "last frame at or before"), so it can behave weirdly when you use it to fetch bounds. For example, if a time vector has samples at [0.0, 0.1, 0.2] and you pass start_time=0.15 to get_unit_spike_trains_in_seconds(), time_to_sample_index(0.15) returns frame 1, but frame 1 has time 0.1 and should be excluded. @alejoe91 can confirm.

You do raise a good point, which is that it seems inefficient to scan the whole train, depending on the underlying representation, and in fact I did implement the bounded scan on PhyKilosortSortingExtractor.get_unit_spike_trains(). Maybe what could be done is, get some conservative frame bounds, use those to fetch the underlying trains, and then do a final mask on the result. Something like:

```python
# first_frame_at_or_after() is a hypothetical helper giving conservative bounds
start_frame = None if start_time is None else first_frame_at_or_after(start_time)
end_frame = None if end_time is None else first_frame_at_or_after(end_time)

spike_frames = self.get_unit_spike_train(..., start_frame=start_frame, end_frame=end_frame)
spike_times = self.sample_index_to_time(spike_frames, ...)
if start_time is not None:
    spike_times = spike_times[spike_times >= start_time]
if end_time is not None:
    spike_times = spike_times[spike_times < end_time]
```

It seems plausible to me that this could save time.

Contributor

I did some more exploration of this. Yes, it works. You can save a lot of time doing it (if you are bypassing the cache). There is 1 major issue. It is also a pre-existing issue (IMO) with all the other variants of get_unit_spike_train[s][_in_seconds](). The issue is: What should we do if all of the following are true?

  1. There is a cold cache.
  2. The user requests use_cache=True (the current default).
  3. The user requests a small fraction of the sorting, say 60s from a 24h sorting.

It does not make sense to cache only a fraction of the spike vector. But computing the whole spike vector is hella expensive, and you shouldn't have to do that to get 60s of data. The best thing to do is to NOT build the cache (not propagate use_cache=True). But this is not what get_unit_spike_train() and get_unit_spike_trains() do -- they build the whole cache even if you just want a few frames from 1 unit, if you pass use_cache=True.

The problem, to me, is that use_cache currently means both "compute whole-sorting cache if not present" and "use some fraction of whole-sorting cache". One thing we could do is have an "auto" option:

```python
use_cache: bool | Literal["auto"] = "auto"
```

with

  • "auto": reuse existing cache; do not build full cache for bounded frame/time slice queries
  • True: build/reuse full cache exactly as requested
  • False: bypass cache

This would be a more sensible default, IMO, than use_cache=True. My personal preferred default would just be use_cache=False, so all the big slow spike vector ops are explicitly opt-in, but "auto" could be a compromise.

@grahamfindlay
Contributor

Just brainstorming here: Having both get_unit_spike_train[s](*, ...return_times=True) and get_unit_spike_train[s]_in_seconds() seems a bit clunky. Both return unit spike trains in seconds. The real difference is that the latter takes start/end bounds as an argument, and gives native-timestamp extractors a direct path for input/output in seconds.

Could we do the following?

```python
get_unit_spike_train(
    unit_id,
    segment_index=None,
    start_frame=None,
    end_frame=None,
    start_time=None,
    end_time=None,
    return_times=False,
    use_cache=True,
)
```

and then:

  • Passing both any frame bound and any time bound raises ValueError.
  • If start_time / end_time are supplied, bounds are interpreted as seconds regardless of return_times.
  • return_times controls output units only, not input-bound units.
  • Native timestamp extractors still get first priority when seconds output or seconds bounds are requested?

Seems to me like it could be less confusing for the user, with less code, and simpler logic.

@grahamfindlay
Contributor

Following up on this:

Another question I haven't resolved:

Why is building the spike vector so costly in the first place, even using my new more efficient method for 1-shot'ing it from the full flat arrays? It just feels like there should be a better way. As Samuel pointed out, it's 1 big malloc as opposed to K (K = units) mallocs for a dictionary of trains. I don't think I'm just creating views onto the underlying flat arrays... so what gives? Is it really all in the O(N log N) lexsort, and the need to allocate a lot of space for the segment indices?

@samuelgarcia I did some profiling and I think I may have cracked this. On the dataset I sent you:

  • id-to-index remap: 25.3 s
  • filling the structured array: 5.5 s
  • np.lexsort: 132.1 s
  • sorted gather copy: 15.0 s

So the question I had was, why is lexsorting so slow? It's single-segment (observation: we could be lexsorting the segments separately, then concatenating, instead of sorting all together) and the times are already sorted. Are there really so many co-temporal spikes? Answer: no. There are ~33M (out of 400M) spikes, and this is including all the ugliest clusters (that are most likely to have co-temporal spikes if eg they represent artifact). The maximum tie-group size is 12. So with a lookup table and tie-only, in-place reordering, it looks like I can build the full spike vector in about 11 s after load, plus 3.4 s for sortedness verification. Turns out numpy's lexsort is not so hot. There's still reordering to optimize, in order to get spike trains, but it looks like there might be hope yet for the spike vector. Disclaimer: more testing needed before I am willing to declare victory.
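The tie-only reordering could look something like this. A plain-Python sketch with illustrative names: the real thing would be numba-jitted or vectorized, and would permute full spike-vector rows rather than return a permutation.

```python
import numpy as np


def tie_only_order(samples, units):
    """Given samples already sorted in time, return a permutation that is
    the identity except inside runs of equal sample times (tie groups),
    where it sorts by unit index. O(N) scan plus tiny per-group sorts,
    instead of a full O(N log N) lexsort."""
    n = samples.size
    order = np.arange(n)
    start = 0
    for i in range(1, n + 1):
        if i == n or samples[i] != samples[start]:
            if i - start > 1:  # a tie group of co-temporal spikes
                order[start:i] = start + np.argsort(units[start:i], kind="stable")
            start = i
    return order
```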

Also for @samuelgarcia / the connoisseur: Our assumption that getting spike trains as a dict means K mallocs was wrong - you can allocate one big array, much like the spike vector, put the sorted trains in it, then each dict value is a view onto that (not to be confused with a view onto ._all_spikes). So this is one reason the dict approach can be so fast and still beats the spike vector.
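The "one big allocation, dict of views" point can be demonstrated directly. An illustrative sketch, not the PR code; it returns the backing array only so the test below can verify the views are zero-copy.

```python
import numpy as np


def trains_as_views(all_spikes, all_clusters, unit_ids):
    """One big gather allocation; every dict value is a zero-copy view
    into it, so a dict of K trains does not mean K mallocs."""
    # stable sort by cluster keeps each unit's spikes time-ordered,
    # because all_spikes is already sorted in time
    order = np.argsort(all_clusters, kind="stable")
    flat = all_spikes[order]  # the single big allocation
    sorted_clusters = all_clusters[order]
    starts = np.searchsorted(sorted_clusters, unit_ids, side="left")
    stops = np.searchsorted(sorted_clusters, unit_ids, side="right")
    trains = {uid: flat[a:b] for uid, a, b in zip(unit_ids, starts, stops)}
    return trains, flat
```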
